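# 1. Pre-process the input image (resizing, rotation).
# 2. Reshape the ground truth to shape = (w, h, kp_num) = (64, 64, n).
# 3. Return a generator that feeds the model and supplies the targets for the loss computation.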
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform  # 用于图像的IO和变换
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from matplotlib import pyplot as plt

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class TestDataSet(Dataset):
    """Test-time image dataset for clothing keypoint detection.

    Image paths are read from the first column of a CSV file. Each item is
    the image resized to 64x64 and returned as a float tensor; no labels are
    returned.
    """

    def __init__(self, csv_file, root_dir, num, transforms_img=None):
        """
        Initialise the dataset.

        :param csv_file: path of the CSV file; its first column holds the
            image paths, relative to ``root_dir``
        :param root_dir: directory that contains the images
        :param num: unused by this class; kept so existing callers keep working
        :param transforms_img: optional callable applied to the resized image
            (an ndarray). When omitted, the array is converted with
            ``torch.from_numpy`` and keeps its original axis order.
        """
        super().__init__()
        self.data_info = self.get_file_info(csv_file)
        self.root_dir = root_dir
        self.transform_img = transforms_img

    def __len__(self):
        """Return the number of images listed in the CSV file."""
        # len() of the Series itself; indexing it first would give the
        # character count of the first filename instead.
        return len(self.data_info)

    def __getitem__(self, idx):
        """
        Load, resize and convert the image at position ``idx``.

        :param idx: zero-based position of the sample
        :return: the image as a float tensor
        """
        out_h, out_w = 64, 64
        # Positional lookup, so the result does not depend on index labels.
        img_path = os.path.join(self.root_dir, self.data_info.iloc[idx])
        image = io.imread(img_path)
        image = self.change_img_size(image, out_h, out_w)

        if self.transform_img is not None:
            image = self.transform_img(image)
        else:
            # No transform supplied: plain ndarray -> Tensor conversion.
            image = torch.from_numpy(image)

        # NOTE(review): skimage.transform.resize converts to float in [0, 1]
        # by default (preserve_range=False), so this extra division may leave
        # values in [0, 1/255]. Kept to preserve current outputs -- confirm
        # against the preprocessing used for training.
        image = image / 255.0

        return image.float()

    @staticmethod
    def get_file_info(file_path):
        """
        Read the CSV file and return its first column.

        :param file_path: path of the CSV file
        :return: pandas Series holding the first column (the image paths)
        """
        file_info = pd.read_csv(file_path)
        img_list = file_info.iloc[:, 0]
        return img_list

    @staticmethod
    def change_img_size(image, h, w):
        """
        Resize ``image`` to ``h`` x ``w`` using skimage.

        :param image: image as an ndarray
        :param h: target height in pixels
        :param w: target width in pixels
        :return: the resized image
        """
        # The output shape is cast to int so float sizes are also accepted.
        return transform.resize(image, (int(h), int(w)))


class ToTensor:
    """Callable that converts an ndarray sample into a torch Tensor."""

    def __call__(self, sample):
        """Return ``sample`` wrapped as a Tensor via ``torch.from_numpy``."""
        tensor = torch.from_numpy(sample)
        return tensor


# Preprocessing applied to every resized image: convert it to a Tensor.
# Normalisation is currently disabled:
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform_img = transforms.Compose([transforms.ToTensor()])

# Module-level dataset; constructing it reads the CSV file at import time.
dressDataset = TestDataSet(
    csv_file=r"..\\test\\test.csv",
    root_dir=r"..\\test\\",
    num=13,
    transforms_img=transform_img,
)

# Batches of 8 images, drawn in shuffled order.
dataloader = DataLoader(dressDataset, batch_size=8, shuffle=True)

if __name__ == "__main__":
    # Smoke test: fetch the first batch and show up to eight images in a
    # 2x4 grid. Nothing is shown when the loader yields no batch.
    img = next(iter(dataloader), None)
    if img is not None:
        picture = img.numpy()
        max_pixel = picture.max()
        # Stretch so the brightest pixel maps to 255. An all-zero batch is
        # left untouched to avoid dividing by zero.
        if max_pixel > 0:
            picture = picture * 255 / max_pixel
        mat = np.uint8(picture)

        plt.figure(num='dress')
        # The batch may hold fewer than 8 images when the dataset is small,
        # so iterate over what is actually there instead of a fixed range.
        for i, single in enumerate(mat[:8]):
            plt.subplot(2, 4, i + 1)  # 2 rows x 4 columns of sub-plots
            plt.title(str(i + 1))
            # Move the first axis last so imshow receives (rows, cols, channels).
            plt.imshow(single.transpose(1, 2, 0))
        plt.show()
        print(img.shape)
